{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fd51733c",
   "metadata": {},
   "source": [
    "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"400\">](https://github.com/jeshraghian/snntorch/)\n",
    "\n",
    "# snnTorch - Spiking Autoencoder (SAE) using Convolutional Spiking Neural Networks\n",
    "## Tutorial by Alon Loeffler (www.alonloeffler.com)\n",
    "\n",
    "*This tutorial is adapted from my original article published on <a href=\"https://medium.com/@alon.loeffler/implementing-a-spiking-autoencoder-sae-and-varational-autoencoder-svae-in-snntorch-8bf267e5b4a0\">Medium.com</a>.*\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_sae.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\n",
    "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6e00599",
   "metadata": {},
   "source": [
    "For a comprehensive overview of how SNNs work and what is going on under the hood, [you might be interested in the snnTorch tutorial series available here.](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)\n",
    "\n",
    "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
    "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". Proceedings of the IEEE, 111(9) September 2023.](https://ieeexplore.ieee.org/abstract/document/10242251) </cite>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad8d3c7a",
   "metadata": {},
   "source": [
    "In this tutorial, you will learn how to use snnTorch to:\n",
    "* Create a spiking Autoencoder\n",
    "* Reconstruct MNIST images\n",
    "\n",
    "If running in Google Colab:\n",
    "* You may connect to GPU by checking `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c81724f7",
   "metadata": {},
   "source": [
    "## 1. Autoencoders\n",
    "\n",
    "An autoencoder is a neural network that is trained to reconstruct its input data. \n",
    "It consists of two main components: \n",
    "1) An encoder  \n",
    "2) A decoder \n",
    "\n",
    "The encoder takes in input data (e.g. an image) and maps it to a lower-dimensional latent space. For example, an encoder might take a 28 x 28 pixel MNIST image (784 pixels total) as input and extract the important features from the image while compressing it to a smaller dimensionality (e.g. 32 features). This compressed representation of the image is called the *latent representation*. \n",
    "\n",
    "The decoder maps the latent representation back to the original input space (i.e. from 32 features back to 784 pixels), and tries to reconstruct the original image from a small number of key features. \n",
    "\n",
    "<center>\n",
    "   <figure>\n",
    "<img src='https://miro.medium.com/max/828/0*dHxZ5LCq5kEPltWH.webp' width=\"600\">\n",
    "<figcaption> Example of a simple Autoencoder where x is the input data, z is the encoded latent space, and x' is the reconstructed inputs once z is decoded (source: Wikipedia). </figcaption>\n",
    "            </figure>\n",
    "</center>\n",
    "\n",
    "\n",
    "\n",
    "The goal of the autoencoder is to minimize the reconstruction error between the input data and the output of the decoder. \n",
    "\n",
    "This is achieved by training the model to minimize the reconstruction loss, which is typically defined as the mean squared error (MSE) between the input and the reconstructed output.\n",
    "\n",
    "<br>\n",
    "\n",
    "<center>\n",
    "   <figure>\n",
    "<img src='https://miro.medium.com/max/640/1*kjfms6RCnHVMLRSq75AD0Q.webp'>\n",
    "<figcaption> MSE loss equation. Here, $y$ would represent the original image (y true) and $\\hat{y}$ would represent the reconstructed outputs (y pred) (source: Towards Data Science). </figcaption>\n",
    "            </figure>\n",
    "</center>\n",
    "\n",
    "<br>\n",
    "\n",
    "Autoencoders are excellent tools for reducing noise in data: they keep only the important parts of the data and discard everything else during the reconstruction process. This makes them an effective dimensionality reduction tool.\n"
   ]
  },
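  {
   "cell_type": "markdown",
   "id": "a3f09c1e",
   "metadata": {},
   "source": [
    "To make the reconstruction loss concrete, the MSE shown in the figure above can be written out as follows. This is the standard definition rather than anything specific to snnTorch; the symbol $n$ (the number of pixels in an image) is notation introduced here for illustration.\n",
    "\n",
    "$$\\mathrm{MSE}(y, \\hat{y}) = \\frac{1}{n}\\sum_{i=1}^{n}\\left(y_i - \\hat{y}_i\\right)^2$$\n",
    "\n",
    "For a 28 x 28 MNIST image, $n = 784$. In PyTorch, `torch.nn.functional.mse_loss` averages the squared differences over all elements of the batch by default."
   ]
  },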
  {
   "cell_type": "markdown",
   "id": "e944a4e0",
   "metadata": {},
   "source": [
    "In this tutorial (similar to <a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_1_spikegen.ipynb\">tutorial 1</a>), we will assume we have some non-spiking input data (i.e., the MNIST dataset) and that we want to encode it and reconstruct it. So let's get started! "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d92a1ab",
   "metadata": {},
   "source": [
    "## 2. Setting Up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e87fc8e3",
   "metadata": {},
   "source": [
    "### 2.1 Install/Import packages and set up environment "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba526215",
   "metadata": {},
   "source": [
    "To start, we need to install snnTorch and its dependencies (note that this tutorial assumes you already have PyTorch and torchvision installed; both come preinstalled in Colab). \n",
    "You can do this by running the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd47a82",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install snntorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14b50acc",
   "metadata": {},
   "source": [
    "Next, let’s import the necessary modules and set up the SAE model. \n",
    "\n",
    "We can use PyTorch to define the encoder and decoder networks, and snnTorch to convert the neurons in the networks into leaky integrate-and-fire (LIF) neurons, which read in and output spikes. \n",
    "\n",
    "We will be using convolutional neural networks (CNNs), covered in <a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_6_CNN.ipynb\">tutorial 6</a>, as the basis of our encoder and decoder.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "d1efbb16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import utils as utls\n",
    "\n",
    "import snntorch as snn\n",
    "from snntorch import utils\n",
    "from snntorch import surrogate\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "#Define the SAE model:\n",
    "class SAE(nn.Module):\n",
    "    def __init__(self,latent_dim):\n",
    "        super().__init__()\n",
    "        self.latent_dim = latent_dim #dimensions of the encoded z-space data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f11183d",
   "metadata": {},
   "source": [
    "## 3. Building the Autoencoder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9192abd3",
   "metadata": {},
   "source": [
    "### 3.1 DataLoaders\n",
    "We will be using the MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "b42dcd16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataloader arguments\n",
    "batch_size = 250\n",
    "data_path='/data/mnist'\n",
    "\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "189d3d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a transform\n",
    "input_size = 32 #for the sake of this tutorial, we will be resizing the original MNIST from 28 to 32\n",
    "\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((input_size, input_size)),\n",
    "            transforms.Grayscale(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0,), (1,))])\n",
    "\n",
    "# Load MNIST\n",
    "\n",
    "# Training data\n",
    "train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transform, download=True)\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# Testing data\n",
    "test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transform, download=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d916c07",
   "metadata": {},
   "source": [
    "### 3.2 The Encoder\n",
    "Let's start building the components of our autoencoder, which we will gradually combine into the SAE model we defined above:\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44259b04",
   "metadata": {},
   "source": [
    "\n",
    "First, let's add an encoder with three convolutional layers (`nn.Conv2d`), and one fully-connected linear output layer. \n",
    "\n",
    " - We will use a kernel of size 3, with padding of 1 and stride of 2 for the CNN hyperparameters. \n",
    "\n",
    " - We also add a Batch Norm layer between convolutional layers. Since we will be using the neuron membrane potential as outputs from each neuron, normalization will help our training process.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "81befdad",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Define the SAE model:\n",
    "class SAE(nn.Module):\n",
    "    def __init__(self,latent_dim):\n",
    "        super().__init__()\n",
    "        self.latent_dim = latent_dim #dimensions of the encoded z-space data\n",
    "        \n",
    "        # Encoder\n",
    "        self.encoder = nn.Sequential(nn.Conv2d(1, 32, 3,padding = 1,stride=2), # Conv Layer 1\n",
    "                            nn.BatchNorm2d(32),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh), #SNN TORCH LIF NEURON\n",
    "                            nn.Conv2d(32, 64, 3,padding = 1,stride=2), # Conv Layer 2\n",
    "                            nn.BatchNorm2d(64),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.Conv2d(64, 128, 3,padding = 1,stride=2), # Conv Layer 3\n",
    "                            nn.BatchNorm2d(128),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.Flatten(start_dim = 1, end_dim = 3), #Flatten convolutional output\n",
    "                            nn.Linear(128*4*4, latent_dim), # Fully connected linear layer\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)\n",
    "                            )"
   ]
  },
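  {
   "cell_type": "markdown",
   "id": "5d27b8f4",
   "metadata": {},
   "source": [
    "The `128*4*4` input size of the linear layer follows from the convolution hyperparameters: with a kernel of size 3, padding of 1 and stride of 2, each `nn.Conv2d` halves the spatial size, so a 32 x 32 input becomes 16 x 16, then 8 x 8, then 4 x 4. Below is a minimal sketch to check this on a dummy batch. The names `convs` and `dummy` are illustrative and not part of the SAE class, and the batch norm and spiking layers are left out because they do not change the shape.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Same convolution hyperparameters as the encoder, without BatchNorm or LIF layers\n",
    "convs = nn.Sequential(\n",
    "    nn.Conv2d(1, 32, 3, padding=1, stride=2),\n",
    "    nn.Conv2d(32, 64, 3, padding=1, stride=2),\n",
    "    nn.Conv2d(64, 128, 3, padding=1, stride=2),\n",
    ")\n",
    "\n",
    "dummy = torch.zeros(2, 1, 32, 32)  # [Batch, Channels, Height, Width]\n",
    "out = convs(dummy)\n",
    "print(out.shape)                       # torch.Size([2, 128, 4, 4])\n",
    "print(out.flatten(start_dim=1).shape)  # torch.Size([2, 2048])\n",
    "```\n",
    "\n",
    "If you change `input_size` or the stride, the first argument of `nn.Linear` has to be updated to match the flattened size printed here."
   ]
  },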
  {
   "cell_type": "markdown",
   "id": "eb658440",
   "metadata": {},
   "source": [
    "### 3.3 The Decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ac938bd",
   "metadata": {},
   "source": [
    "Before we write the decoder, there is one more small step required. \n",
    "When decoding the latent z-space data, we need to move from the flattened encoded representation (of size `latent_dim`) back to a tensor representation to use in transposed convolution. \n",
    "\n",
    "To do so, we need to run an additional fully-connected linear layer transforming the data back into a tensor of 128 x 4 x 4.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "7a72df7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Define the SAE model:\n",
    "class SAE(nn.Module):\n",
    "    def __init__(self,latent_dim):\n",
    "        super().__init__()\n",
    "        self.latent_dim = latent_dim #dimensions of the encoded z-space data\n",
    "        \n",
    "        # Encoder\n",
    "        self.encoder = nn.Sequential(nn.Conv2d(1, 32, 3,padding = 1,stride=2), # Conv Layer 1\n",
    "                            nn.BatchNorm2d(32),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh), #SNN TORCH LIF NEURON\n",
    "                            nn.Conv2d(32, 64, 3,padding = 1,stride=2), # Conv Layer 2\n",
    "                            nn.BatchNorm2d(64),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.Conv2d(64, 128, 3,padding = 1,stride=2), # Conv Layer 3\n",
    "                            nn.BatchNorm2d(128),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.Flatten(start_dim = 1, end_dim = 3), #Flatten convolutional output\n",
    "                            nn.Linear(128*4*4, latent_dim), # Fully connected linear layer\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)\n",
    "                            )\n",
    "\n",
    "        # From latent back to tensor for convolution\n",
    "        self.linearNet= nn.Sequential(nn.Linear(latent_dim,128*4*4),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f872c42",
   "metadata": {},
   "source": [
    "Now we can write the decoder, with three transposed convolutional (`nn.ConvTranspose2d`) layers followed by a final LIF output layer. \n",
    "Although the linear layer above restores the number of features (128 x 4 x 4 = 2048), its output is still a flat, 1-dimensional vector for each sample, so we need to unflatten it into a tensor of 128 x 4 x 4 before the transposed convolutions. This is done using `nn.Unflatten` in the first line of the decoder.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "36c41f53",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Define the SAE model:\n",
    "class SAE(nn.Module):\n",
    "    def __init__(self,latent_dim):\n",
    "        super().__init__()\n",
    "        self.latent_dim = latent_dim #dimensions of the encoded z-space data\n",
    "        \n",
    "        # Encoder\n",
    "        self.encoder = nn.Sequential(nn.Conv2d(1, 32, 3,padding = 1,stride=2), # Conv Layer 1\n",
    "                            nn.BatchNorm2d(32),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh), #SNN TORCH LIF NEURON\n",
    "                            nn.Conv2d(32, 64, 3,padding = 1,stride=2), # Conv Layer 2\n",
    "                            nn.BatchNorm2d(64),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.Conv2d(64, 128, 3,padding = 1,stride=2), # Conv Layer 3\n",
    "                            nn.BatchNorm2d(128),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.Flatten(start_dim = 1, end_dim = 3), #Flatten convolutional output\n",
    "                            nn.Linear(128*4*4, latent_dim), # Fully connected linear layer\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)\n",
    "                            )\n",
    "\n",
    "        # From latent back to tensor for convolution\n",
    "        self.linearNet = nn.Sequential(nn.Linear(latent_dim,128*4*4),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh))\n",
    "        # Decoder\n",
    "        self.decoder = nn.Sequential(nn.Unflatten(1,(128,4,4)), #Unflatten data from 1 dim to tensor of 128 x 4 x 4\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.ConvTranspose2d(128, 64, 3,padding = 1,stride=(2,2),output_padding=1),\n",
    "                            nn.BatchNorm2d(64),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.ConvTranspose2d(64, 32, 3,padding = 1,stride=(2,2),output_padding=1),\n",
    "                            nn.BatchNorm2d(32),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                            nn.ConvTranspose2d(32, 1, 3,padding = 1,stride=(2,2),output_padding=1),\n",
    "                            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,output=True,threshold=20000) #make large so membrane can be trained\n",
    "                            )"
   ]
  },
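  {
   "cell_type": "markdown",
   "id": "c90e4a17",
   "metadata": {},
   "source": [
    "The decoder mirrors the shapes of the encoder. For `nn.ConvTranspose2d` with a dilation of 1, the PyTorch documentation gives the output size along each spatial dimension as shown below. The symbols $H$, $s$, $p$, $k$ and $p_{out}$ (size, stride, padding, kernel size and output padding) are shorthand introduced here.\n",
    "\n",
    "$$H_{out} = (H_{in} - 1)\\,s - 2p + k + p_{out}$$\n",
    "\n",
    "With $s = 2$, $p = 1$, $k = 3$ and $p_{out} = 1$, this simplifies to $H_{out} = 2H_{in}$, so the feature maps grow from 4 x 4 to 8 x 8, then 16 x 16, and finally 32 x 32, which matches the resized input images."
   ]
  },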
  {
   "cell_type": "markdown",
   "id": "c4bbc656",
   "metadata": {},
   "source": [
    "One important thing to note is that in the final Leaky layer, our spiking threshold is set extremely high (`threshold=20000` rather than `thresh`). This is a neat trick in snnTorch, which allows the neuron membrane in the final layer to be updated continuously without ever reaching the spiking threshold.\n",
    "\n",
    "The output of each Leaky Neuron will consist of a tensor of spikes (0 or 1) and a tensor of neuron membrane potential (negative or positive real numbers). snnTorch allows us to use either the spikes or membrane potential of each neuron in training. We will be using the membrane potential output from the final layer for the image reconstruction.\n"
   ]
  },
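  {
   "cell_type": "markdown",
   "id": "e14b6d3a",
   "metadata": {},
   "source": [
    "To see why a very high threshold keeps the output neurons silent, here is a simplified leaky integrate-and-fire update written in plain PyTorch. It is only a sketch of the idea, not snnTorch's actual `snn.Leaky` implementation. The function name `lif_step`, the value of `beta` and the constant input are assumptions chosen for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def lif_step(x, mem, beta, threshold):\n",
    "    # One simplified LIF step: decay the membrane, add the input, then check for a spike\n",
    "    mem = beta * mem + x\n",
    "    spk = (mem > threshold).float()\n",
    "    mem = mem - spk * threshold  # subtract the threshold from neurons that spiked\n",
    "    return spk, mem\n",
    "\n",
    "beta = 0.5         # illustrative decay rate\n",
    "x = torch.ones(1)  # constant input current\n",
    "\n",
    "for threshold in (1.0, 20000.0):\n",
    "    mem = torch.zeros(1)\n",
    "    total_spikes = 0.0\n",
    "    for step in range(5):\n",
    "        spk, mem = lif_step(x, mem, beta, threshold)\n",
    "        total_spikes += spk.item()\n",
    "    print(f'threshold={threshold}: spikes={total_spikes}, final membrane={mem.item():.4f}')\n",
    "```\n",
    "\n",
    "With the threshold at 20000 the neuron never spikes, so its membrane potential is never reset and behaves like a continuous, trainable output."
   ]
  },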
  {
   "cell_type": "markdown",
   "id": "f5ede94a",
   "metadata": {},
   "source": [
    "### 3.4 Forward Function\n",
    "Finally, let’s write the forward, encode and decode functions, before putting it all together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "d63e3d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(self, x): \n",
    "    utils.reset(self.encoder) #need to reset the hidden states of LIF \n",
    "    utils.reset(self.decoder)\n",
    "    utils.reset(self.linearNet) \n",
    "    \n",
    "    #encode\n",
    "    spk_mem=[];spk_rec=[];encoded_x=[]\n",
    "    for step in range(num_steps): #for t in time\n",
    "        spk_x,mem_x=self.encode(x) #Output spike trains and neuron membrane states\n",
    "        spk_rec.append(spk_x) \n",
    "        spk_mem.append(mem_x)\n",
    "    spk_rec=torch.stack(spk_rec,dim=2) # stack spikes along a new time dimension (index 2): [Batch,latent_dim,Time]\n",
    "    spk_mem=torch.stack(spk_mem,dim=2) # stack membranes along a new time dimension (index 2): [Batch,latent_dim,Time]\n",
    "    \n",
    "    #decode\n",
    "    spk_mem2=[];spk_rec2=[];decoded_x=[]\n",
    "    for step in range(num_steps): #for t in time\n",
    "        x_recon,x_mem_recon=self.decode(spk_rec[...,step]) \n",
    "        spk_rec2.append(x_recon) \n",
    "        spk_mem2.append(x_mem_recon)\n",
    "    spk_rec2=torch.stack(spk_rec2,dim=4)\n",
    "    spk_mem2=torch.stack(spk_mem2,dim=4)  \n",
    "    out = spk_mem2[:,:,:,:,-1] #return the membrane potential of the output neuron at t = -1 (last t)\n",
    "    return out \n",
    "\n",
    "def encode(self,x):\n",
    "    spk_latent_x,mem_latent_x=self.encoder(x) \n",
    "    return spk_latent_x,mem_latent_x\n",
    "\n",
    "def decode(self,x):\n",
    "    spk_x,mem_x = self.linearNet(x) #convert latent dimension back to total size of features in encoder final layer\n",
    "    spk_x2,mem_x2=self.decoder(spk_x)\n",
    "    return spk_x2,mem_x2   "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb5398f2",
   "metadata": {},
   "source": [
    "There are a couple of key things to notice here:\n",
    "\n",
    "1) At the beginning of each call of our forward function, we need to reset the hidden states (i.e., the membrane potentials) of each LIF neuron. If we do not do this, the states carry over from the previous batch and PyTorch will raise gradient errors when we try to backpropagate. To do so we use `utils.reset`.\n",
    "\n",
    "\n",
    "2) In the forward function, when we call the encode and decode functions, we do so in a loop. This is because we are converting static images into spike trains, as explained previously. Spike trains are defined over time, $t$: at each time step a neuron either spikes or does not. Therefore, we pass the original image through the encoder at each of the `num_steps` time steps to create a spiking latent representation, $z$, and then decode that representation one time step at a time."
   ]
  },
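  {
   "cell_type": "markdown",
   "id": "7b82f0c5",
   "metadata": {},
   "source": [
    "The bookkeeping in the time loop can be sketched in isolation. The snippet below uses random binary tensors in place of real encoder outputs to show how per-step outputs are stacked along a new time dimension and then read back one step at a time. The sizes and the name `fake_spikes` are placeholders, not values from the tutorial.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "batch_size, latent_dim, num_steps = 4, 32, 5  # placeholder sizes\n",
    "\n",
    "spk_rec = []\n",
    "for step in range(num_steps):\n",
    "    # Stand-in for the encoder's spike output at one time step: [Batch, latent_dim]\n",
    "    fake_spikes = (torch.rand(batch_size, latent_dim) > 0.5).float()\n",
    "    spk_rec.append(fake_spikes)\n",
    "\n",
    "# Stack along a new dimension at index 2: [Batch, latent_dim, Time]\n",
    "spk_rec = torch.stack(spk_rec, dim=2)\n",
    "print(spk_rec.shape)           # torch.Size([4, 32, 5])\n",
    "\n",
    "# spk_rec[..., step] selects all spikes at one time step: [Batch, latent_dim]\n",
    "print(spk_rec[..., 0].shape)   # torch.Size([4, 32])\n",
    "print(spk_rec[..., -1].shape)  # torch.Size([4, 32])\n",
    "```\n",
    "\n",
    "The decoder outputs are four-dimensional, which is why the forward function stacks them with `dim=4` instead."
   ]
  },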
  {
   "cell_type": "markdown",
   "id": "6ebe201d",
   "metadata": {},
   "source": [
    "For example, converting a sample digit 7 from the MNIST dataset into a spike train with a latent dimension of 32 and t = 50 might look like this:\n",
    "<center>\n",
    "   <figure>\n",
    "<img src='https://miro.medium.com/max/640/1*GuTXNnAfm-Ilc3Lk27r1Ng.webp'>\n",
    "<figcaption> Spike-Train of sample MNIST digit 7 after encoding. Other instances of 7 will have slightly different spike-trains, and different digits will have even more different spike-trains. </figcaption>\n",
    "            </figure>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2db8162",
   "metadata": {},
   "source": [
    "### 3.5 Putting it all together:\n",
    "Our final, complete SAE class should look like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "7daa238e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        #Encoder\n",
    "        self.encoder = nn.Sequential(nn.Conv2d(1, 32, 3,padding = 1,stride=2),\n",
    "                          nn.BatchNorm2d(32),\n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                          nn.Conv2d(32, 64, 3,padding = 1,stride=2),\n",
    "                          nn.BatchNorm2d(64),\n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                          nn.Conv2d(64, 128, 3,padding = 1,stride=2),\n",
    "                          nn.BatchNorm2d(128),\n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                          nn.Flatten(start_dim = 1, end_dim = 3),\n",
    "                          nn.Linear(2048, latent_dim), #this needs to be the final layer output size (channels * pixels * pixels)\n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)\n",
    "                          )\n",
    "        # From latent back to tensor for convolution\n",
    "        self.linearNet= nn.Sequential(nn.Linear(latent_dim,128*4*4),\n",
    "                               snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh))\n",
    "        \n",
    "        #Decoder\n",
    "        self.decoder = nn.Sequential(nn.Unflatten(1,(128,4,4)), \n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                          nn.ConvTranspose2d(128, 64, 3,padding = 1,stride=(2,2),output_padding=1),\n",
    "                          nn.BatchNorm2d(64),\n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                          nn.ConvTranspose2d(64, 32, 3,padding = 1,stride=(2,2),output_padding=1),\n",
    "                          nn.BatchNorm2d(32),\n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),\n",
    "                          nn.ConvTranspose2d(32, 1, 3,padding = 1,stride=(2,2),output_padding=1),\n",
    "                          snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,output=True,threshold=20000) #make large so membrane can be trained\n",
    "                          )\n",
    "        \n",
    "    def forward(self, x): #Dimensions: [Batch,Channels,Width,Length]\n",
    "        utils.reset(self.encoder) #need to reset the hidden states of LIF \n",
    "        utils.reset(self.decoder)\n",
    "        utils.reset(self.linearNet) \n",
    "        \n",
    "        #encode\n",
    "        spk_mem=[];spk_rec=[];encoded_x=[]\n",
    "        for step in range(num_steps): #for t in time\n",
    "            spk_x,mem_x=self.encode(x) #Output spike trains and neuron membrane states\n",
    "            spk_rec.append(spk_x) \n",
    "            spk_mem.append(mem_x)\n",
    "        spk_rec=torch.stack(spk_rec,dim=2)\n",
    "        spk_mem=torch.stack(spk_mem,dim=2) #Dimensions:[Batch,latent_dim,Time]\n",
    "        \n",
    "        #decode\n",
    "        spk_mem2=[];spk_rec2=[];decoded_x=[]\n",
    "        for step in range(num_steps): #for t in time\n",
    "            x_recon,x_mem_recon=self.decode(spk_rec[...,step]) \n",
    "            spk_rec2.append(x_recon) \n",
    "            spk_mem2.append(x_mem_recon)\n",
    "        spk_rec2=torch.stack(spk_rec2,dim=4)\n",
    "        spk_mem2=torch.stack(spk_mem2,dim=4)#Dimensions:[Batch,Channels,Width,Length, Time]  \n",
    "        out = spk_mem2[:,:,:,:,-1] #return the membrane potential of the output neuron at t = -1 (last t)\n",
    "        return out #Dimensions:[Batch,Channels,Width,Length]\n",
    "\n",
    "    def encode(self,x):\n",
    "        spk_latent_x,mem_latent_x=self.encoder(x) \n",
    "        return spk_latent_x,mem_latent_x\n",
    "\n",
    "    def decode(self,x):\n",
    "        spk_x,mem_x = self.linearNet(x) #convert latent dimension back to total size of features in encoder final layer\n",
    "        spk_x2,mem_x2=self.decoder(spk_x)\n",
    "        return spk_x2,mem_x2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2d8ac35",
   "metadata": {},
   "source": [
    "## 4. Training and Testing\n",
    "Finally, we can move on to training our SAE, and testing its usefulness. We have already loaded the MNIST dataset, and split it into training and testing sets.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7946babc",
   "metadata": {},
   "source": [
    "### 4.1 Training Function\n",
    "\n",
    "We define our training function, which takes in the network model, training data loader, optimizer and epoch number as inputs, and returns the loss value of the final batch after running all batches of the current epoch. \n",
    "\n",
    "As discussed at the beginning, we will be using MSE loss to compare the reconstructed image (`x_recon`) with the original image (`real_img`).\n",
    "\n",
    "As always, to set up our gradients for backprop we use `opti.zero_grad()`, and then call `loss_val.backward()` and `opti.step()` to perform backprop. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "546dc6af",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Training \n",
    "def train(network, trainloader, opti, epoch): \n",
    "    \n",
    "    network=network.train()\n",
    "    train_loss_hist=[]\n",
    "    for batch_idx, (real_img, labels) in enumerate(trainloader):   \n",
    "        opti.zero_grad()\n",
    "        real_img = real_img.to(device)\n",
    "        labels = labels.to(device)\n",
    "        \n",
    "        #Pass data into network, and return reconstructed image from Membrane Potential at t = -1\n",
    "        x_recon = network(real_img) #Dimensions passed in: [Batch_size,Input_size,Image_Width,Image_Length] \n",
    "        \n",
    "        #Calculate loss        \n",
    "        loss_val = F.mse_loss(x_recon, real_img)\n",
    "                \n",
    "        print(f'Train[{epoch}/{max_epoch}][{batch_idx}/{len(trainloader)}] Loss: {loss_val.item()}')\n",
    "\n",
    "        loss_val.backward()\n",
    "        opti.step()\n",
    "\n",
    "        #Save reconstructed images at the end of each epoch\n",
    "        if batch_idx == len(trainloader)-1:\n",
    "            # NOTE: you need to create training/ and testing/ folders in your chosen path\n",
    "            utls.save_image((real_img+1)/2, f'figures/training/epoch{epoch}_finalbatch_inputs.png') \n",
    "            utls.save_image((x_recon+1)/2, f'figures/training/epoch{epoch}_finalbatch_recon.png')\n",
    "    return loss_val\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ba0b79b",
   "metadata": {},
   "source": [
    "### 4.2 Testing Function\n",
    "The testing function is nearly identical to the training function, except that we do not backpropagate. No gradients are required, so we wrap the loop in `torch.no_grad()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "1016eafa",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Testing \n",
    "def test(network, testloader, opti, epoch):\n",
    "    network=network.eval()\n",
    "    test_loss_hist=[]\n",
    "    with torch.no_grad(): #no gradient this time\n",
    "        for batch_idx, (real_img, labels) in enumerate(testloader):   \n",
    "            real_img = real_img.to(device)\n",
    "            labels = labels.to(device)\n",
    "            x_recon = network(real_img)\n",
    "\n",
    "            loss_val = F.mse_loss(x_recon, real_img)\n",
    "\n",
    "            print(f'Test[{epoch}/{max_epoch}][{batch_idx}/{len(testloader)}]  Loss: {loss_val.item()}')\n",
    "                \n",
    "            if batch_idx == len(testloader)-1:\n",
    "                utls.save_image((real_img+1)/2, f'figures/testing/epoch{epoch}_finalbatch_inputs.png')\n",
    "                utls.save_image((x_recon+1)/2, f'figures/testing/epoch{epoch}_finalbatch_recons.png')\n",
    "    return loss_val"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "698354f1",
   "metadata": {},
   "source": [
    "There are a couple of ways to calculate loss with spiking neural networks. Here, we are simply taking the membrane potential of the final layer of output neurons at the last time step ($t = 5$). \n",
    "\n",
    "Therefore, we only need to compare each original image with its corresponding decoded, reconstructed image once per epoch. We can also return the membrane potentials at each time step, and create t different versions of the reconstructed image, and then compare each of them with the original image and take the average loss. For those of you interested in this, you can replace the loss function above with something like this:\n",
    "\n",
    "(*note this will fail to run as we have not defined any of the variables yet, it is just here for illustrative purposes*)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "8384ee4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: assumes x_recon holds the membrane potential at every time step,\n",
    "# with dimensions [Batch,Channels,Width,Length,Time] (i.e. forward returns spk_mem2 rather than its last slice)\n",
    "train_loss_hist=[]\n",
    "loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
    "for step in range(num_steps):\n",
    "    loss_val += F.mse_loss(x_recon[...,step], real_img) #compare the reconstruction at each time step\n",
    "train_loss_hist.append(loss_val.item())\n",
    "avg_loss=loss_val/num_steps #average loss over all time steps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b6f66e9",
   "metadata": {},
   "source": [
    "## 5. Conclusion: Running the SAE\n",
    "Now, finally, we can run our SAE model. Let’s define some parameters, and run training and testing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "484c4d75",
   "metadata": {},
   "source": [
    "Let's create directories where we can save our original and reconstructed images for training and testing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "b69ecd4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create training/ and testing/ folders in your chosen path\n",
    "if not os.path.isdir('figures/training'):\n",
    "    os.makedirs('figures/training')\n",
    "    \n",
    "if not os.path.isdir('figures/testing'):\n",
    "    os.makedirs('figures/testing')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "76c46ac3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train[0/10][0/240] Loss: 0.10109379142522812\n",
      "Train[0/10][1/240] Loss: 0.10465191304683685\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# dataloader arguments\n",
    "batch_size = 250\n",
    "input_size = 32 #resize of mnist data (optional)\n",
    "\n",
    "#setup GPU\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# neuron and simulation parameters\n",
    "spike_grad = surrogate.atan(alpha=2.0)# alternate surrogate gradient fast_sigmoid(slope=25) \n",
    "beta = 0.5 #decay rate of neurons \n",
    "num_steps=5\n",
    "latent_dim = 32 #dimension of latent layer (how compressed we want the information)\n",
    "thresh=1#spiking threshold (lower = more spikes are let through)\n",
    "epochs=10 \n",
    "max_epoch=epochs\n",
    "\n",
    "#Define Network and optimizer\n",
    "net=SAE()\n",
    "net = net.to(device)\n",
    "\n",
    "optimizer = torch.optim.AdamW(net.parameters(), \n",
    "                            lr=0.0001,\n",
    "                            betas=(0.9, 0.999), \n",
    "                            weight_decay=0.001)\n",
    "\n",
    "#Run training and testing        \n",
    "for e in range(epochs): \n",
    "    train_loss = train(net, train_loader, optimizer, e)\n",
    "    test_loss = test(net,test_loader,optimizer,e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f50d374",
   "metadata": {},
   "source": [
    "After only 10 epochs, our training and testing reconstructed losses should be around 0.05, and our reconstructed images should look something like this:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40201f00",
   "metadata": {},
   "source": [
    "<img src=\"https://miro.medium.com/max/828/1*-wTHiusUhqn-zabYTpkdEw.webp\" width = 500>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a419d377",
   "metadata": {},
   "source": [
    "Yes, the reconstructed images are a bit blurry, and the loss isn’t perfect, but from only 10 epochs, and only using the final membrane potential at $t = 5$ for our reconstructed loss, it’s a pretty decent start!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1aab0eb8",
   "metadata": {},
   "source": [
    "Try increasing the number of epochs, or playing around with `thresh`, `num_steps` and `batch_size` to see if you can get better loss!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
